import os
import numpy as np
import time
from tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as tfs

from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment as linear_assignment
import argparse
import faiss
from inat_loader import INAT
import pickle

import sys
sys.path.append('../pretraining/dino')
from utils import load_pretrained_weights
import vision_transformer as vits

def _match(flat_preds, flat_targets):
    out_to_gts = {}
    out_to_gts_scores = {}
    for out_c in range(1000):
        for gt_c in range(1000):
            # the amount of out_c at all the gt_c samples
            tp_score = int(((flat_preds == out_c) * (flat_targets == gt_c)).sum())
            if (out_c not in out_to_gts) or (tp_score > out_to_gts_scores[out_c]):
                out_to_gts[out_c] = gt_c
                out_to_gts_scores[out_c] = tp_score
    return list(out_to_gts.items())

def _hungarian_match(flat_preds, flat_targets, preds_k, targets_k):
    num_samples = flat_targets.shape[0]

    num_k = preds_k
    num_correct = np.zeros((num_k, num_k))

    for c1 in range(num_k):
        for c2 in range(num_k):
            # elementwise, so each sample contributes once
            votes = int(((flat_preds == c1) * (flat_targets == c2)).sum())
            num_correct[c1, c2] = votes

    # num_correct is small
    match = linear_assignment(num_samples - num_correct)

    # return as list of tuples, out_c to gt_c
    res = []
    for out_c, gt_c in match:
        res.append((out_c, gt_c))

    return res

def get_kmeans_labels(all_feats, GT_NUM_CLASSES, *, niter=50, nredo=5, seed=42,
                      verbose=True, use_gpu=True, gpu_device=0):
    """Cluster feature rows with faiss k-means and return each row's cluster id.

    Args:
        all_feats: 2-D array of shape (num_samples, dim); converted once to a
            C-contiguous float32 array, which faiss requires.
        GT_NUM_CLASSES: number of k-means centroids.
        niter: k-means iterations per run.
        nredo: number of k-means restarts (faiss keeps the best run).
        seed: random seed for centroid initialisation.
        verbose: forward faiss training progress to stdout.
        use_gpu: train/assign with a faiss GPU flat index on ``gpu_device``;
            set False to use a CPU ``IndexFlatL2`` instead.
        gpu_device: CUDA device ordinal used when ``use_gpu`` is True.

    Returns:
        1-D array of length num_samples with the nearest-centroid id per row.

    Raises:
        ValueError: if ``all_feats`` is not 2-D.
    """
    # One conversion for both train and search; astype() alone can keep a
    # Fortran layout, which faiss cannot consume.
    feats = np.ascontiguousarray(all_feats, dtype=np.float32)
    if feats.ndim != 2:
        raise ValueError(f"all_feats must be 2-D, got shape {feats.shape}")
    d = feats.shape[1]

    kmeans = faiss.Clustering(d, int(GT_NUM_CLASSES))
    kmeans.verbose = verbose
    kmeans.niter = niter
    kmeans.nredo = nredo
    kmeans.seed = seed

    if use_gpu:
        # `res` must stay alive while the GPU index is in use (function scope).
        res = faiss.StandardGpuResources()
        flat_config = faiss.GpuIndexFlatConfig()
        flat_config.device = gpu_device
        index = faiss.GpuIndexFlatL2(res, d, flat_config)
    else:
        index = faiss.IndexFlatL2(d)

    # After training, the index holds the centroids, so a 1-NN search yields
    # each sample's cluster assignment.
    kmeans.train(feats, index)
    _, pred_labels = index.search(feats, 1)
    # reshape(-1) instead of squeeze(): stays 1-D even for a single sample.
    return pred_labels.reshape(-1)

@torch.no_grad()
def _extract_features(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and collect outputs and labels.

    Returns:
        ``(trues, all_feats)``: an int32 label vector of length
        ``len(loader.dataset)`` and a float32 matrix with one row of model
        output per sample, in loader order.

    Raises:
        ValueError: if the loader yields no batches (the feature width is taken
            from the first model output, so nothing can be allocated).
    """
    num_samples = len(loader.dataset)
    trues = np.zeros(num_samples, dtype=np.int32)
    all_feats = None
    current = 0
    for images, labels in tqdm(loader, total=len(loader)):
        images = images.to(device)
        batch_size = images.size(0)
        out = model(images)
        if all_feats is None:
            # Feature width is only known once the model has produced an output.
            all_feats = np.zeros((num_samples, out.size(1)), dtype=np.float32)
        all_feats[current:current + batch_size] = out.cpu().numpy()
        trues[current:current + batch_size] = labels
        current += batch_size
    if all_feats is None:
        raise ValueError("data loader yielded no batches; nothing to cluster")
    return trues, all_feats


@torch.no_grad()
def get_features(model, val_loader, ckpt_path, device='cuda', *,
                 data_path=None, is_train=None, recompute=None, num_classes=None):
    """Return ground-truth labels, k-means cluster ids and raw features.

    Features are cached as a pickle in the checkpoint's directory and reused on
    later runs unless ``recompute`` is true.

    Args:
        model: network whose output for an image batch is a 2-D
            (batch, dim) tensor.
        val_loader: DataLoader yielding ``(images, labels)`` batches in a fixed
            order.
        ckpt_path: checkpoint path; its directory hosts the feature cache.
        device: device the image batches are moved to.
        data_path, is_train, recompute, num_classes: optional overrides for the
            module-level ``dataset_path``, ``args.is_train``, ``args.recompute``
            and ``GT_NUM_CLASSES`` globals set by the ``__main__`` section;
            pass them explicitly when calling this function from elsewhere.

    Returns:
        ``(trues, clusters, all_feats)``: labels, k-means cluster ids and the
        feature matrix.
    """
    # Fall back to the script-level globals so the original call keeps working.
    if data_path is None:
        data_path = dataset_path
    if is_train is None:
        is_train = args.is_train
    if recompute is None:
        recompute = args.recompute
    if num_classes is None:
        num_classes = GT_NUM_CLASSES

    # os.path.join keeps a slash-free ckpt_path relative instead of writing to '/'.
    cache_name = f"feats-{data_path.replace('/','-')}-{'train' if is_train else ''}.pth"
    save_path = os.path.join(os.path.dirname(ckpt_path), cache_name)

    if recompute or not os.path.exists(save_path):
        print('starting', flush=True)
        trues, all_feats = _extract_features(model, val_loader, device)
        # Caching is best-effort: write to a temp file and rename, so a failed
        # dump never leaves a truncated cache behind for the next run to load.
        tmp_path = save_path + '.tmp'
        try:
            with open(tmp_path, 'wb') as fh:
                pickle.dump({'labels': trues, 'feats': all_feats}, fh, protocol=4)
            os.replace(tmp_path, save_path)
        except (OSError, pickle.PicklingError) as err:
            print("didnt save", err)
    else:
        print("using precomputed features at:", save_path, flush=True)
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at cache files written by this script.
        with open(save_path, 'rb') as fh:
            precomp = pickle.load(fh)
        all_feats, trues = precomp['feats'], precomp['labels']

    print("unique classes:", len(np.unique(trues)))
    print('computing K-means...', end='', flush=True)
    start_t = time.time()
    clusters = get_kmeans_labels(all_feats, num_classes)
    print(f'done!, took k-means took {(time.time() - start_t) / 60:.2f}min')
    return trues, clusters, all_feats



if __name__ == '__main__':
    # Evaluate a checkpoint by k-means clustering its features on a labelled
    # dataset and scoring the clusters against the ground-truth labels.
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--dataset', help='which dataset', default='IN1k')
    parser.add_argument('--ckpt_path', default='/scratch/shared/beegfs/yuki/adiwol/experiments/mocov2_pretrained/moco_v2_200ep_pretrain.pth.tar',
                        help='path to R50 ckpt')
    parser.add_argument('--recompute', action='store_true',
                        help='recompute cached features / matches even if they exist')
    parser.add_argument('--is_train', action='store_true',
                        help='use the train split instead of val')
    parser.add_argument('--deit', action='store_true',
                        help='uses deit model')
    parser.add_argument('--hungarian', action='store_true',
                        help='also report majority-vote and Hungarian cluster accuracy')

    args = parser.parse_args()
    # dataset_path / GT_NUM_CLASSES stay module-level: get_features reads them.
    if args.dataset == 'IN1k':
        if os.path.exists('/scratch/local/ssd/datasets/Imagenet/val/'):
            dataset_path = "/scratch/local/ssd/datasets/Imagenet/val/"
        else:
            dataset_path = "/scratch/shared/beegfs/yuki/data/ILSVRC12/val/"
        GT_NUM_CLASSES = 1000
    elif args.dataset == 'objnet':
        dataset_path = "/scratch/shared/beegfs/yuki/data/objnet-r256"
        GT_NUM_CLASSES = 313
    elif args.dataset == 'places':
        if os.path.exists('/scratch/local/ssd/datasets/Places/val/'):
            dataset_path = "/scratch/local/ssd/datasets/Places/val/"
        else:
            dataset_path = "/scratch/shared/beegfs/yuki/data/Places/val/"
        GT_NUM_CLASSES = 205
    elif args.dataset == 'inat18':
        dataset_path = '/datasets/inaturalist2018/images/'
        ann_path = '/datasets/inaturalist2018/val2018.json'
        GT_NUM_CLASSES = 8142
    elif args.dataset == 'IN1k-v2':
        dataset_path = '/scratch/shared/beegfs/yuki/fast/imagenetv2-matched-frequency-format-val'
        GT_NUM_CLASSES = 1000
    elif args.dataset == 'flowers':
        dataset_path = '/scratch/shared/beegfs/yuki/data/102flowers/val'
        GT_NUM_CLASSES = 102
    elif args.dataset == 'user_bal':
        dataset_path = '/scratch/shared/beegfs/yuki/data/user_stratified/val'
        GT_NUM_CLASSES = 3000
    else:
        # Fail fast instead of a later NameError on dataset_path.
        parser.error(f"unknown --dataset {args.dataset!r}")

    ckpt_path = args.ckpt_path
    do_hungarian = args.hungarian
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(ckpt_path, dataset_path)

    #* make dataset
    transforms = tfs.Compose([
        tfs.Resize(256),
        tfs.CenterCrop(224),
        tfs.ToTensor(),
        tfs.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    if args.dataset == 'inat18':
        dataset = INAT(dataset_path, ann_path, args.is_train)
    else:
        split_path = dataset_path.replace('val', 'train') if args.is_train else dataset_path
        dataset = torchvision.datasets.ImageFolder(split_path, transforms)
    val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        shuffle=False,  # order must stay fixed so labels line up with features
        num_workers=8,
        pin_memory=True,
    )
    print("lenght of dataset:", len(val_loader.dataset), flush=True)

    #* load model
    if not args.deit:
        model = torchvision.models.resnet50(pretrained=False)
        # map_location lets GPU-saved checkpoints load on CPU-only hosts.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        if any('encoder_q' in k for k in state_dict):
            print('using moco model')
            state_dict = {k.replace("encoder_q.", ""): v for k, v in state_dict.items() if "encoder_q" in k}
        load_msg = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates a mismatched fc/projection head, but an
        # unhandled key prefix would silently leave the backbone random.
        missing = [k for k in load_msg.missing_keys if not k.startswith('fc.')]
        if missing:
            print(f"WARNING: {len(missing)} backbone weights missing from checkpoint, e.g. {missing[:5]}",
                  flush=True)
        model.fc = torch.nn.Identity()
    else:
        model = vits.__dict__['deit_small'](patch_size=16, num_classes=0)
        load_pretrained_weights(model, args.ckpt_path, 'teacher', 'deit_small', 16)

    model.to(device)
    model.eval()

    trues, clusters, all_feats = get_features(model, val_loader, ckpt_path, device)
    preds = clusters

    # Result files are named after the checkpoint's directory, tagged by dataset.
    out_prefix = '/'.join(ckpt_path.split('/')[:-1])
    dataset_tag = dataset_path.replace('/', '-')

    print('Evaluating...')
    if args.dataset != 'user_bal':
        nmi = normalized_mutual_info_score(trues, preds, average_method='arithmetic')
        anmi = adjusted_mutual_info_score(trues, preds, average_method='arithmetic')
        ari = adjusted_rand_score(trues, preds)
        results = {'NMI': nmi, 'ANMI': anmi, 'ARI': ari}
        print(f"{ckpt_path}: NMI={nmi:.3f} aNMI={anmi:.3f} ARI={ari:.3f}", flush=True)

        if do_hungarian:
            num_samples = len(preds)

            # Many-to-one majority vote: each cluster -> its most common class.
            reordered_preds = np.zeros(num_samples, dtype=np.int32)
            for pred_i, target_i in _match(preds, trues):
                reordered_preds[preds == pred_i] = target_i
            acc1 = int((reordered_preds == trues).sum()) / float(num_samples)
            print(f"ACC1:{acc1:.3f}", flush=True)
            results['acc_sort'] = acc1

            # One-to-one Hungarian mapping. The cache file is keyed on checkpoint
            # and dataset so a match computed for another run is never reused.
            match_path = f"{out_prefix}-match-{dataset_tag}.pth"
            if os.path.exists(match_path) and not args.recompute:
                match2 = torch.load(match_path)
            else:
                match2 = [(int(o), int(g)) for o, g in
                          _hungarian_match(preds, trues, GT_NUM_CLASSES, GT_NUM_CLASSES)]
                torch.save(match2, match_path)

            reordered_preds2 = np.zeros(num_samples, dtype=np.int64)
            for pred_i, target_i in match2:
                reordered_preds2[preds == pred_i] = target_i
            acc2 = int((reordered_preds2 == trues).sum()) / float(num_samples)
            results['acc_hung'] = acc2
            print(f"ACC_hung:{acc2:.3f}", flush=True)
        torch.save(results, f"{out_prefix}-results-{dataset_tag}.pth")
    if args.is_train:
        torch.save(preds, f"{out_prefix}-pseudo-labels-{dataset_tag}.pth")



